# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hysop.constants import Precision
from hysop.tools.htypes import check_instance, first_not_None, to_tuple
from hysop.symbolic.array import (
OpenClSymbolicArray,
OpenClSymbolicBuffer,
OpenClSymbolicNdBuffer,
)
from hysop.operator.base.custom_symbolic_operator import SymbolicExpressionParser
from hysop.backend.device.opencl.opencl_env import OpenClEnvironment
from hysop.backend.device.opencl.opencl_kernel_config import OpenClKernelConfig
from hysop.backend.device.opencl.opencl_symbolic import (
OpenClSymbolic,
OpenClAutotunableCustomSymbolicKernel,
)
class OpenClElementwiseKernelGenerator:
    """Build autotuned OpenCL elementwise kernels from symbolic expressions.

    An instance wraps an OpenCL environment together with an
    ``OpenClAutotunableCustomSymbolicKernel`` configured from a kernel
    configuration.  ``elementwise_kernel`` and ``elementwise`` turn symbolic
    expressions into a kernel launcher, and the classmethods are shortcuts to
    create the symbolic objects used in those expressions.
    """

    def __init__(self, cl_env, kernel_config=None, user_build_options=None):
        """Configure the code generator and the kernel autotuner.

        Parameters
        ----------
        cl_env:
            The OpenCL environment; must be an ``OpenClEnvironment``.
        kernel_config:
            An ``OpenClKernelConfig``.  When None, a default-constructed
            configuration is used.
        user_build_options:
            Extra build options, normalized with ``to_tuple``.  When None, no
            extra option is added.
        """
        kernel_config = first_not_None(kernel_config, OpenClKernelConfig())
        user_build_options = to_tuple(first_not_None(user_build_options, ()))

        check_instance(cl_env, OpenClEnvironment)
        check_instance(kernel_config, OpenClKernelConfig)
        check_instance(user_build_options, tuple)

        # Precision.SAME has no reference to follow here, so it is replaced
        # by the default precision.
        precision = kernel_config.precision
        if precision == Precision.SAME:
            precision = Precision.DEFAULT

        float_dump_mode = kernel_config.float_dump_mode
        use_short_circuit_ops = kernel_config.use_short_circuit_ops
        unroll_loops = kernel_config.unroll_loops
        autotuner_config = kernel_config.autotuner_config

        typegen = cl_env.build_typegen(
            precision=precision,
            float_dump_mode=float_dump_mode,
            use_short_circuit_ops=use_short_circuit_ops,
            unroll_loops=unroll_loops,
        )

        # De-duplicate the options while keeping first-seen order.  A set
        # would yield them in hash order, which differs between interpreter
        # runs for strings (hash randomization) and makes the resulting
        # option tuple non reproducible.  Presumably this tuple also takes
        # part in kernel caching -- TODO confirm.
        build_options = {}
        for options in (
            kernel_config.user_build_options,
            cl_env.default_build_opts,
            typegen.ftype_build_options(),
            user_build_options,
        ):
            build_options.update(dict.fromkeys(options))

        kernel_autotuner = OpenClAutotunableCustomSymbolicKernel(
            cl_env=cl_env,
            typegen=typegen,
            build_opts=tuple(build_options),
            autotuner_config=autotuner_config,
        )

        self._cl_env = cl_env
        self._kernel_autotuner = kernel_autotuner

    def elementwise_kernel(self, name, *exprs, **kwds):
        """Autotune a kernel for the given expressions and build its launcher.

        Parameters
        ----------
        name: str
            Name handed to the symbolic expression parser.
        exprs:
            One or more symbolic expressions.
        kwds:
            Optional keyword arguments:

            * queue: queue used for autotuning, defaults to the default
              queue of the OpenCL environment.
            * call_only_once: forwarded to the autotuner as
              ``first_working`` (default False).
            * disable_vectorization: default True.
            * force_volatile: iterable of symbolic variables whose
              ``varname`` is registered as volatile (default empty).
            * max_candidates: forwarded to the autotuner (default None).
            * compute_resolution: forwarded to the parser (default None).
            * debug: forwarded to the autotuner (default False).

        Returns
        -------
        tuple
            ``(kernel_launcher, update_input_parameters)``.

        Raises
        ------
        ValueError
            If an unknown keyword argument is given.
        """
        # 1) call_only_once means that the autotuner will stop at
        #    first successful kernel build and exec.
        # 2) By default we disable vectorization because user given
        #    expressions may not be safe to vectorize.
        check_instance(name, str)
        assert len(exprs) > 0, exprs

        queue = kwds.pop("queue", self._cl_env.default_queue)
        call_only_once = kwds.pop("call_only_once", False)
        disable_vectorization = kwds.pop("disable_vectorization", True)
        force_volatile = kwds.pop("force_volatile", ())
        max_candidates = kwds.pop("max_candidates", None)
        compute_resolution = kwds.pop("compute_resolution", None)
        debug = kwds.pop("debug", False)
        if kwds:
            msg = f"Unknown keyword arguments: {kwds.keys()}"
            raise ValueError(msg)

        expr_info = SymbolicExpressionParser.parse(
            name, {}, *exprs, compute_resolution=compute_resolution
        )
        assert not expr_info.has_direction, expr_info

        # No discrete fields are involved: reset everything related to
        # space/time discretization before discretizing the expressions.
        expr_info.compute_granularity = 0
        expr_info.space_discretization = None
        expr_info.time_integrator = None
        expr_info.interpolation = None
        expr_info.min_ghosts = {}
        expr_info.min_ghosts_per_components = {}
        expr_info.extract_obj_requirements()

        expr_info.discretize_expressions(
            input_dfields={}, output_dfields={}, force_symbolic_axes=True
        )
        expr_info.setup_expressions(None)
        expr_info.check_arrays()
        expr_info.check_buffers()

        for var in force_volatile:
            expr_info.is_volatile.add(var.varname)

        kernel, args_dict, update_input_parameters = self._kernel_autotuner.autotune(
            expr_info=expr_info,
            queue=queue,
            first_working=call_only_once,
            disable_vectorization=disable_vectorization,
            debug=debug,
            max_candidates=max_candidates,
        )

        kl = kernel.build_launcher(**args_dict)
        return (kl, update_input_parameters)

    def elementwise(self, name, *exprs, **kwds):
        """Return a callable that runs the kernel built for the expressions.

        Accepts the same arguments as ``elementwise_kernel``.  The returned
        function takes optional ``queue``, ``kernel`` and
        ``update_input_parameters`` keyword arguments (bound by default to
        the values obtained here) and returns
        ``kernel(queue=queue, **update_input_parameters())``.
        """
        kernel, update_input_parameters = self.elementwise_kernel(name, *exprs, **kwds)
        # elementwise_kernel received its own copy of kwds, so "queue" is
        # still available here.
        queue = kwds.get("queue", self._cl_env.default_queue)

        def call_kernel(
            queue=queue, kernel=kernel, update_input_parameters=update_input_parameters
        ):
            return kernel(queue=queue, **update_input_parameters())

        return call_kernel

    @classmethod
    def symbolic_buffers(cls, *names, **kwds):
        """Forward to ``OpenClSymbolic.symbolic_buffers``."""
        return OpenClSymbolic.symbolic_buffers(*names, **kwds)

    @classmethod
    def symbolic_ndbuffers(cls, *names, **kwds):
        """Forward to ``OpenClSymbolic.symbolic_ndbuffers``."""
        return OpenClSymbolic.symbolic_ndbuffers(*names, **kwds)

    @classmethod
    def symbolic_arrays(cls, *names, **kwds):
        """Forward to ``OpenClSymbolic.symbolic_arrays``."""
        return OpenClSymbolic.symbolic_arrays(*names, **kwds)

    @classmethod
    def symbolic_tmp_scalars(cls, *names, **kwds):
        """Forward to ``OpenClSymbolic.symbolic_tmp_scalars``."""
        return OpenClSymbolic.symbolic_tmp_scalars(*names, **kwds)

    @classmethod
    def symbolic_constants(cls, *names, **kwds):
        """Forward to ``OpenClSymbolic.symbolic_constants``."""
        return OpenClSymbolic.symbolic_constants(*names, **kwds)

    @classmethod
    def arrays_to_symbols(cls, *arrays, **kwds):
        """Wrap each array in an ``OpenClSymbolicArray`` named ``a0``, ``a1``, ...

        Extra keyword arguments are passed to every symbol constructor.
        Returns a tuple of symbols, in the order of ``arrays``.
        """
        return tuple(
            OpenClSymbolicArray(name=f"a{i}", memory_object=array, **kwds)
            for i, array in enumerate(arrays)
        )

    @classmethod
    def arrays_to_ndbuffers(cls, *arrays, **kwds):
        """Wrap each array in an ``OpenClSymbolicNdBuffer`` named ``ab0``, ``ab1``, ...

        Extra keyword arguments are passed to every symbol constructor.
        Returns a tuple of symbols, in the order of ``arrays``.
        """
        return tuple(
            OpenClSymbolicNdBuffer(name=f"ab{i}", memory_object=array, **kwds)
            for i, array in enumerate(arrays)
        )

    @classmethod
    def dfields_to_ndbuffers(cls, *dfields, **kwds):
        """Wrap each scalar discrete field in an ``OpenClSymbolicNdBuffer``.

        Each symbol takes the name, ``sbuffer`` and ghosts of its field.
        Every field must be scalar (checked with an assertion).  Returns a
        tuple of symbols, in the order of ``dfields``.
        """
        symbols = []
        for dfield in dfields:
            assert dfield.is_scalar
            symbols.append(
                OpenClSymbolicNdBuffer(
                    name=dfield.name,
                    memory_object=dfield.sbuffer,
                    ghosts=dfield.ghosts,
                    **kwds,
                )
            )
        return tuple(symbols)

    @classmethod
    def buffer_to_symbols(cls, *buffers, **kwds):
        """Wrap each buffer in an ``OpenClSymbolicBuffer`` named ``b0``, ``b1``, ...

        Extra keyword arguments are passed to every symbol constructor.
        Returns a tuple of symbols, in the order of ``buffers``.
        """
        return tuple(
            OpenClSymbolicBuffer(name=f"b{i}", memory_object=buf, **kwds)
            for i, buf in enumerate(buffers)
        )